/*
 *   This program is free software: you can redistribute it and/or modify
 *   it under the terms of the GNU General Public License as published by
 *   the Free Software Foundation, either version 3 of the License, or
 *   (at your option) any later version.
 *
 *   This program is distributed in the hope that it will be useful,
 *   but WITHOUT ANY WARRANTY; without even the implied warranty of
 *   MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
 *   GNU General Public License for more details.
 *
 *   You should have received a copy of the GNU General Public License
 *   along with this program.  If not, see <http://www.gnu.org/licenses/>.
 */

/*
 *    RBFNetwork.java
 *    Copyright (C) 2004 University of Waikato, Hamilton, New Zealand
 *
 */
package weka.classifiers.functions;

import java.util.Collections;
import java.util.Enumeration;
import java.util.Vector;

import weka.classifiers.AbstractClassifier;
import weka.classifiers.Classifier;
import weka.clusterers.MakeDensityBasedClusterer;
import weka.clusterers.SimpleKMeans;
import weka.core.Capabilities;
import weka.core.Instance;
import weka.core.Instances;
import weka.core.Option;
import weka.core.OptionHandler;
import weka.core.RevisionUtils;
import weka.core.SelectedTag;
import weka.core.Utils;
import weka.core.WeightedInstancesHandler;
import weka.filters.Filter;
import weka.filters.unsupervised.attribute.ClusterMembership;
import weka.filters.unsupervised.attribute.Standardize;

/**
 * <!-- globalinfo-start --> Class that implements a normalized Gaussian radial
 * basis function network.<br/>
 * It uses the k-means clustering algorithm to provide the basis functions and
 * learns either a logistic regression (discrete class problems) or linear
 * regression (numeric class problems) on top of that. Symmetric multivariate
 * Gaussians are fit to the data from each cluster. If the class is nominal it
 * uses the given number of clusters per class. It standardizes all numeric
 * attributes to zero mean and unit variance.
 * <p/>
 * <!-- globalinfo-end -->
 * 
 * <!-- options-start --> Valid options are:
 * <p/>
 * 
 * <pre>
 * -B &lt;number&gt;
 *  Set the number of clusters (basis functions) to generate. (default = 2).
 * </pre>
 * 
 * <pre>
 * -S &lt;seed&gt;
 *  Set the random seed to be used by K-means. (default = 1).
 * </pre>
 * 
 * <pre>
 * -R &lt;ridge&gt;
 *  Set the ridge value for the logistic or linear regression.
 * </pre>
 * 
 * <pre>
 * -M &lt;number&gt;
 *  Set the maximum number of iterations for the logistic regression. (default -1, until convergence).
 * </pre>
 * 
 * <pre>
 * -W &lt;number&gt;
 *  Set the minimum standard deviation for the clusters. (default 0.1).
 * </pre>
 * 
 * <!-- options-end -->
 * 
 * @author Mark Hall
 * @author Eibe Frank
 * @version $Revision$
 */
public class RBFNetwork extends AbstractClassifier implements OptionHandler, WeightedInstancesHandler {

  /** for serialization */
  private static final long serialVersionUID = -3669814959712675720L;

  /** The logistic regression for classification problems */
  private Logistic m_logistic;

  /** The linear regression for numeric problems */
  private LinearRegression m_linear;

  /** The filter for producing the meta data (cluster membership attributes) */
  private ClusterMembership m_basisFilter;

  /** Filter used for standardizing the data to zero mean / unit variance */
  private Standardize m_standardize;

  /** The number of clusters (basis functions to generate) */
  private int m_numClusters = 2;

  /** The ridge parameter for the logistic or linear regression. */
  protected double m_ridge = 1e-8;

  /** The maximum number of iterations for logistic regression (-1 = until convergence). */
  private int m_maxIts = -1;

  /** The seed to pass on to K-means */
  private int m_clusteringSeed = 1;

  /** The minimum standard deviation for the cluster Gaussians */
  private double m_minStdDev = 0.1;

  /** a ZeroR model in case no model can be built from the data */
  private Classifier m_ZeroR;

  /**
   * Returns a string describing this classifier
   * 
   * @return a description of the classifier suitable for displaying in the
   *         explorer/experimenter gui
   */
  public String globalInfo() {
    // Note: the first two fragments previously concatenated to
    // "radial basisbasis function network"; fixed to read correctly.
    return "Class that implements a normalized Gaussian radial basis "
      + "function network.\n"
      + "It uses the k-means clustering algorithm to provide the basis "
      + "functions and learns either a logistic regression (discrete "
      + "class problems) or linear regression (numeric class problems) "
      + "on top of that. Symmetric multivariate Gaussians are fit to "
      + "the data from each cluster. If the class is "
      + "nominal it uses the given number of clusters per class. "
      + "It standardizes all numeric "
      + "attributes to zero mean and unit variance.";
  }

  /**
   * Returns default capabilities of the classifier, i.e., and "or" of Logistic
   * and LinearRegression.
   * 
   * @return the capabilities of this classifier
   * @see Logistic
   * @see LinearRegression
   */
  @Override
  public Capabilities getCapabilities() {
    Capabilities result = new Logistic().getCapabilities();
    result.or(new LinearRegression().getCapabilities());
    // restrict attribute capabilities to what K-means can handle,
    // but keep the class capabilities of the regression schemes
    Capabilities classes = result.getClassCapabilities();
    result.and(new SimpleKMeans().getCapabilities());
    result.or(classes);
    return result;
  }

  /**
   * Builds the classifier: standardizes the data, clusters it with K-means,
   * fits Gaussians to the clusters (via MakeDensityBasedClusterer), and then
   * trains a logistic (nominal class) or linear (numeric class) regression
   * on the cluster-membership attributes.
   * 
   * @param instances the training data
   * @throws Exception if the classifier could not be built successfully
   */
  @Override
  public void buildClassifier(Instances instances) throws Exception {

    // can classifier handle the data?
    getCapabilities().testWithFail(instances);

    // remove instances with missing class
    instances = new Instances(instances);
    instances.deleteWithMissingClass();

    // only class? -> build ZeroR model
    if (instances.numAttributes() == 1) {
      System.err
        .println("Cannot build model (only class attribute present in data!), "
          + "using ZeroR model instead!");
      m_ZeroR = new weka.classifiers.rules.ZeroR();
      m_ZeroR.buildClassifier(instances);
      return;
    } else {
      m_ZeroR = null;
    }

    // standardize all numeric attributes to zero mean / unit variance
    m_standardize = new Standardize();
    m_standardize.setInputFormat(instances);
    instances = Filter.useFilter(instances, m_standardize);

    // cluster the data and turn cluster memberships into meta attributes
    SimpleKMeans sk = new SimpleKMeans();
    sk.setNumClusters(m_numClusters);
    sk.setSeed(m_clusteringSeed);
    MakeDensityBasedClusterer dc = new MakeDensityBasedClusterer();
    dc.setClusterer(sk);
    dc.setMinStdDev(m_minStdDev);
    m_basisFilter = new ClusterMembership();
    m_basisFilter.setDensityBasedClusterer(dc);
    m_basisFilter.setInputFormat(instances);
    Instances transformed = Filter.useFilter(instances, m_basisFilter);

    // fit the appropriate regression model on the transformed data
    if (instances.classAttribute().isNominal()) {
      m_linear = null;
      m_logistic = new Logistic();
      m_logistic.setRidge(m_ridge);
      m_logistic.setMaxIts(m_maxIts);
      m_logistic.buildClassifier(transformed);
    } else {
      m_logistic = null;
      m_linear = new LinearRegression();
      m_linear.setAttributeSelectionMethod(new SelectedTag(
        LinearRegression.SELECTION_NONE, LinearRegression.TAGS_SELECTION));
      m_linear.setRidge(m_ridge);
      m_linear.buildClassifier(transformed);
    }
  }

  /**
   * Computes the distribution for a given instance
   * 
   * @param instance the instance for which distribution is computed
   * @return the distribution
   * @throws Exception if the distribution can't be computed successfully
   */
  @Override
  public double[] distributionForInstance(Instance instance) throws Exception {

    // default model?
    if (m_ZeroR != null) {
      return m_ZeroR.distributionForInstance(instance);
    }

    // push the instance through the same filter chain used during training
    m_standardize.input(instance);
    m_basisFilter.input(m_standardize.output());
    Instance transformed = m_basisFilter.output();

    return ((instance.classAttribute().isNominal() ? m_logistic
      .distributionForInstance(transformed) : m_linear
      .distributionForInstance(transformed)));
  }

  /**
   * Returns a description of this classifier as a String
   * 
   * @return a description of this classifier
   */
  @Override
  public String toString() {

    // only ZeroR model?
    if (m_ZeroR != null) {
      StringBuilder buf = new StringBuilder();
      buf.append(this.getClass().getName().replaceAll(".*\\.", "") + "\n");
      buf.append(this.getClass().getName().replaceAll(".*\\.", "")
        .replaceAll(".", "=")
        + "\n\n");
      buf
        .append("Warning: No model could be built, hence ZeroR model is used:\n\n");
      buf.append(m_ZeroR.toString());
      return buf.toString();
    }

    if (m_basisFilter == null) {
      return "No classifier built yet!";
    }

    StringBuilder sb = new StringBuilder();
    sb.append("Radial basis function network\n");
    sb.append((m_linear == null) ? "(Logistic regression "
      : "(Linear regression ");
    sb.append("applied to K-means clusters as basis functions):\n\n");
    sb.append((m_linear == null) ? m_logistic.toString() : m_linear.toString());
    return sb.toString();
  }

  /**
   * Returns the tip text for this property
   * 
   * @return tip text for this property suitable for displaying in the
   *         explorer/experimenter gui
   */
  public String maxItsTipText() {
    return "Maximum number of iterations for the logistic regression to perform. "
      + "Only applied to discrete class problems.";
  }

  /**
   * Get the value of MaxIts.
   * 
   * @return Value of MaxIts.
   */
  public int getMaxIts() {

    return m_maxIts;
  }

  /**
   * Set the value of MaxIts.
   * 
   * @param newMaxIts Value to assign to MaxIts.
   */
  public void setMaxIts(int newMaxIts) {

    m_maxIts = newMaxIts;
  }

  /**
   * Returns the tip text for this property
   * 
   * @return tip text for this property suitable for displaying in the
   *         explorer/experimenter gui
   */
  public String ridgeTipText() {
    return "Set the Ridge value for the logistic or linear regression.";
  }

  /**
   * Sets the ridge value for logistic or linear regression.
   * 
   * @param ridge the ridge
   */
  public void setRidge(double ridge) {
    m_ridge = ridge;
  }

  /**
   * Gets the ridge value.
   * 
   * @return the ridge
   */
  public double getRidge() {
    return m_ridge;
  }

  /**
   * Returns the tip text for this property
   * 
   * @return tip text for this property suitable for displaying in the
   *         explorer/experimenter gui
   */
  public String numClustersTipText() {
    return "The number of clusters for K-Means to generate.";
  }

  /**
   * Set the number of clusters for K-means to generate. Values less than 1
   * are silently ignored (the current setting is kept).
   * 
   * @param numClusters the number of clusters to generate.
   */
  public void setNumClusters(int numClusters) {
    if (numClusters > 0) {
      m_numClusters = numClusters;
    }
  }

  /**
   * Return the number of clusters to generate.
   * 
   * @return the number of clusters to generate.
   */
  public int getNumClusters() {
    return m_numClusters;
  }

  /**
   * Returns the tip text for this property
   * 
   * @return tip text for this property suitable for displaying in the
   *         explorer/experimenter gui
   */
  public String clusteringSeedTipText() {
    return "The random seed to pass on to K-means.";
  }

  /**
   * Set the random seed to be passed on to K-means.
   * 
   * @param seed a seed value.
   */
  public void setClusteringSeed(int seed) {
    m_clusteringSeed = seed;
  }

  /**
   * Get the random seed used by K-means.
   * 
   * @return the seed value.
   */
  public int getClusteringSeed() {
    return m_clusteringSeed;
  }

  /**
   * Returns the tip text for this property
   * 
   * @return tip text for this property suitable for displaying in the
   *         explorer/experimenter gui
   */
  public String minStdDevTipText() {
    return "Sets the minimum standard deviation for the clusters.";
  }

  /**
   * Get the MinStdDev value.
   * 
   * @return the MinStdDev value.
   */
  public double getMinStdDev() {
    return m_minStdDev;
  }

  /**
   * Set the MinStdDev value.
   * 
   * @param newMinStdDev The new MinStdDev value.
   */
  public void setMinStdDev(double newMinStdDev) {
    m_minStdDev = newMinStdDev;
  }

  /**
   * Returns an enumeration describing the available options
   * 
   * @return an enumeration of all the available options
   */
  @Override
  public Enumeration<Option> listOptions() {
    Vector<Option> newVector = new Vector<Option>(5);

    newVector.addElement(new Option(
      "\tSet the number of clusters (basis functions) "
        + "to generate. (default = 2).", "B", 1, "-B <number>"));
    newVector.addElement(new Option(
      "\tSet the random seed to be used by K-means. " + "(default = 1).", "S",
      1, "-S <seed>"));
    newVector.addElement(new Option(
      "\tSet the ridge value for the logistic or " + "linear regression.", "R",
      1, "-R <ridge>"));
    newVector.addElement(new Option("\tSet the maximum number of iterations "
      + "for the logistic regression." + " (default -1, until convergence).",
      "M", 1, "-M <number>"));
    newVector.addElement(new Option("\tSet the minimum standard "
      + "deviation for the clusters." + " (default 0.1).", "W", 1,
      "-W <number>"));

    newVector.addAll(Collections.list(super.listOptions()));

    return newVector.elements();
  }

  /**
   * Parses a given list of options.
   * <p/>
   * 
   * <!-- options-start --> Valid options are:
   * <p/>
   * 
   * <pre>
   * -B &lt;number&gt;
   *  Set the number of clusters (basis functions) to generate. (default = 2).
   * </pre>
   * 
   * <pre>
   * -S &lt;seed&gt;
   *  Set the random seed to be used by K-means. (default = 1).
   * </pre>
   * 
   * <pre>
   * -R &lt;ridge&gt;
   *  Set the ridge value for the logistic or linear regression.
   * </pre>
   * 
   * <pre>
   * -M &lt;number&gt;
   *  Set the maximum number of iterations for the logistic regression. (default -1, until convergence).
   * </pre>
   * 
   * <pre>
   * -W &lt;number&gt;
   *  Set the minimum standard deviation for the clusters. (default 0.1).
   * </pre>
   * 
   * <!-- options-end -->
   * 
   * @param options the list of options as an array of strings
   * @throws Exception if an option is not supported
   */
  @Override
  public void setOptions(String[] options) throws Exception {
    setDebug(Utils.getFlag('D', options));

    // -R: ridge for the logistic/linear regression (default 1e-8)
    String ridgeString = Utils.getOption('R', options);
    if (ridgeString.length() != 0) {
      m_ridge = Double.parseDouble(ridgeString);
    } else {
      m_ridge = 1.0e-8;
    }

    // -M: max logistic-regression iterations (default -1 = until convergence)
    String maxItsString = Utils.getOption('M', options);
    if (maxItsString.length() != 0) {
      m_maxIts = Integer.parseInt(maxItsString);
    } else {
      m_maxIts = -1;
    }

    String numClustersString = Utils.getOption('B', options);
    if (numClustersString.length() != 0) {
      setNumClusters(Integer.parseInt(numClustersString));
    }

    String seedString = Utils.getOption('S', options);
    if (seedString.length() != 0) {
      setClusteringSeed(Integer.parseInt(seedString));
    }
    String stdString = Utils.getOption('W', options);
    if (stdString.length() != 0) {
      setMinStdDev(Double.parseDouble(stdString));
    }

    super.setOptions(options);

    Utils.checkForRemainingOptions(options);
  }

  /**
   * Gets the current settings of the classifier.
   * 
   * @return an array of strings suitable for passing to setOptions
   */
  @Override
  public String[] getOptions() {

    Vector<String> options = new Vector<String>();

    options.add("-B");
    options.add("" + m_numClusters);
    options.add("-S");
    options.add("" + m_clusteringSeed);
    options.add("-R");
    options.add("" + m_ridge);
    options.add("-M");
    options.add("" + m_maxIts);
    options.add("-W");
    options.add("" + m_minStdDev);

    Collections.addAll(options, super.getOptions());

    return options.toArray(new String[0]);
  }

  /**
   * Returns the revision string.
   * 
   * @return the revision
   */
  @Override
  public String getRevision() {
    return RevisionUtils.extract("$Revision$");
  }

  /**
   * Main method for testing this class.
   * 
   * @param argv should contain the command line arguments to the scheme (see
   *          Evaluation)
   */
  public static void main(String[] argv) {
    runClassifier(new RBFNetwork(), argv);
  }
}
